
[ROCm] Enable inference quantization tests on ROCm (Float8, Int8, per-token)#4044

Open
brucechanglongxu wants to merge 1 commit into pytorch:main from
brucechanglongxu:rocm-enable-inference-quantization-tests

Conversation


@brucechanglongxu brucechanglongxu commented Mar 11, 2026

Removes the blanket @skip_if_rocm("ROCm enablement in progress") decorator from inference quantization tests that already pass on MI300X (gfx942). These tests were skipped during the initial ROCm bringup, but the underlying quantization configs now work.

test_workflow_e2e_numerics in test_quant_api.py runs end-to-end quantize + inference + SQNR checks for several configs. On ROCm, Float8WeightOnly, Int8DynActInt8Weight, and Int8WeightOnly all pass. GemliteUIntX is already gated by has_gemlite. Float8DynActFloat8Weight with default PerTensor granularity is skipped -- see note below.

test_per_token_linear_cuda in test_integration.py tests _quant_int8_dynamic_per_token_linear on GPU across float32/float16/bfloat16. Passes on MI300X with SQNR >= 39 on all dtypes.
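The SQNR thresholds above compare dequantized output against the reference. As a hedged illustration (not torchao's actual implementation), SQNR can be computed as the ratio of signal power to quantization-noise power in dB; the helper name `sqnr_db` and the toy rounding quantizer below are assumptions for this sketch:

```python
import math

def sqnr_db(ref, quantized):
    """Signal-to-quantization-noise ratio in dB: 10*log10(P_signal / P_noise)."""
    signal_power = sum(x * x for x in ref)
    noise_power = sum((x - y) ** 2 for x, y in zip(ref, quantized))
    return 10.0 * math.log10(signal_power / noise_power)

# Toy example: quantize to a coarse grid, then a finer one, and compare SQNR.
ref = [0.11, -0.52, 0.93, -0.34, 0.75]
coarse = [round(x / 0.25) * 0.25 for x in ref]  # step 0.25
fine = [round(x / 0.05) * 0.05 for x in ref]    # step 0.05
print(sqnr_db(ref, coarse), sqnr_db(ref, fine))
```

A finer quantization grid leaves less noise and therefore a higher SQNR, which is why thresholds like SQNR >= 39 serve as a pass/fail numerics check.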

test_flatten_unflatten in test_affine_quantized.py tests __tensor_flatten__ / __tensor_unflatten__ roundtrip on Int8 quantized tensors. Passes on both CPU and CUDA on ROCm.
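For context, the flatten/unflatten contract splits a tensor subclass into its inner tensor attributes plus non-tensor metadata, then reconstructs it. A minimal plain-Python sketch of that roundtrip (the `ToyQuantTensor` class and its fields are hypothetical, not torchao's Int8 tensor):

```python
# Hypothetical stand-in for a quantized tensor subclass: int data plus a scale.
class ToyQuantTensor:
    def __init__(self, int_data, scale):
        self.int_data = int_data  # quantized integer values
        self.scale = scale        # dequantization scale (non-tensor metadata)

    def __tensor_flatten__(self):
        # Names of inner tensor attributes, plus a dict of metadata.
        return ["int_data"], {"scale": self.scale}

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, metadata,
                             outer_size=None, outer_stride=None):
        return cls(inner_tensors["int_data"], metadata["scale"])

t = ToyQuantTensor([12, -7, 5], 0.02)
names, meta = t.__tensor_flatten__()
inner = {name: getattr(t, name) for name in names}
t2 = ToyQuantTensor.__tensor_unflatten__(inner, meta)
assert t2.int_data == t.int_data and t2.scale == t.scale
```

The test exercises exactly this roundtrip on real Int8 quantized tensors, so it is device-agnostic and passes on both CPU and CUDA builds of ROCm.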

Tests that remain skipped with updated reasons:

  • test_print_quantized_module: torch.backends.cusparselt.is_available() returns True on this ROCm machine (hipSPARSELt backend detected) but SemiSparseLayout fails at runtime with "hipSPARSELt not supported on your machine". Updated the skip message to reflect the actual blocker.
  • test_int4_weight_only_quant_subclass_api_grouped: see note below.
  • test_int8_weight_only_quant_with_freeze: unchanged, marked flaky.

Two bugs discovered during this investigation (not ROCm-specific but surfaced here because the mslk codepath is unavailable on AMD):

  1. _is_128_128_scaled false positive on per-tensor scaled tensors: when a weight tensor happens to be exactly 128x128, PerTensor granularity produces block_size=[128, 128]. _is_128_128_scaled only checks block_size[0] == 128 and block_size[1] == 128, so it returns True even though _is_tensorwise_scaled also returns True for the same tensor. The torch kernel path in float8_tensor.py then hits elif _is_128_128_scaled(weight_tensor) before checking the tensorwise case, and asserts that the input must be _is_1_128_scaled, which fails. On CUDA with SM90+ the mslk path is taken instead, so this never triggers there. The fix would be to check _is_tensorwise_scaled before _is_128_128_scaled in the dispatch chain, or to have _is_128_128_scaled exclude the tensorwise case. Float8DynActFloat8Weight with PerRow granularity works fine on ROCm (SQNR=28.9).

  2. _weight_int4pack_mm assertion on small output dimensions: test_int4_weight_only_quant_subclass_api_grouped uses test shapes with N=16 and N=8. The tinygemm kernel (_weight_int4pack_mm) asserts qScaleAndZeros.size(1) == n and fails on these shapes. Int4WeightOnly with tile_packed_to_4d works on normal-sized shapes (tested 128x128 with group_size=32, SQNR=24.4). This is likely a pre-existing constraint of the tinygemm packing format on ROCm rather than something introduced by these changes.

Tested on MI300X (gfx942) in a ROCm Docker container.

…-token)

Remove blanket @skip_if_rocm from tests that already pass on MI300X:
- test_workflow_e2e_numerics: Float8WeightOnly, Int8DynActInt8Weight,
  Int8WeightOnly all pass. Float8DynActFloat8Weight with default
  PerTensor is skipped due to a _is_128_128_scaled false positive when
  the linear is exactly 128x128 (block_size coincides with shape);
  PerRow granularity works fine.
- test_per_token_linear_cuda: per-token int8 dynamic quantization
  works on ROCm across float32/float16/bfloat16.
- test_flatten_unflatten: Int8 quantized tensor flatten/unflatten
  roundtrip works on ROCm.

Remaining skips with updated reasons:
- test_print_quantized_module: hipSPARSELt reports available via
  torch.backends but fails at runtime on MI300X.
- test_int4_weight_only_quant_subclass_api_grouped: _weight_int4pack_mm
  hits a qScaleAndZeros size assertion on the small N=16/N=8 shapes
  used in this test.
- test_int8_weight_only_quant_with_freeze: kept as-is (flaky).

pytorch-bot bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4044

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 11, 2026

Labels: CLA Signed, module: rocm